{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting the age of abalone with Linear Regression\n",
    "\n",
    "## Background\n",
    "\n",
    "Predicting the age of abalone from physical measurements.  The age of abalone is determined by cutting the shell through the cone, staining it, and counting the number of rings through a microscope -- a boring and time-consuming task.  Other measurements, which are easier to obtain, are used to predict the age.\n",
    "\n",
    "## Problem\n",
    "\n",
    "Build regression models by sex which can predict the number of rings."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fetching data from the Internet\n",
    "\n",
    "The training data is located on the website. Thanks for Greenplum's [external table](https://docs.vmware.com/en/VMware-Tanzu-Greenplum/6/greenplum-database/GUID-admin_guide-external-g-creating-and-using-web-external-tables.html), accessing CSV data from a HTTP server is easy.\n",
    "\n",
    "First, connect to the database. You may want to change the connect string like `postgresql://<user>@<host>:<port>/<db_name>` to fit your needs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext sql\n",
    "%sql postgresql://localhost/gpadmin"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, create the external table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "vscode": {
     "languageId": "sql"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://localhost/gpadmin\n",
      "Done.\n",
      "Done.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "-- External Table\n",
    "DROP EXTERNAL TABLE IF EXISTS abalone_external;\n",
    "CREATE EXTERNAL WEB TABLE abalone_external(\n",
    "    sex text\n",
    "    , length float8\n",
    "    , diameter float8\n",
    "    , height float8\n",
    "    , whole_weight float8\n",
    "    , shucked_weight float8\n",
    "    , viscera_weight float8\n",
    "    , shell_weight float8\n",
    "    , rings integer -- target variable to predict\n",
    ") EXECUTE 'curl http://archive.ics.uci.edu/ml/machine-learning-databases/abalone/abalone.data'\n",
    "format 'CSV'\n",
    "(null as '?');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://localhost/gpadmin\n",
      "Done.\n",
      "12531 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "-- Create abalone table from an external table\n",
    "DROP TABLE IF EXISTS abalone;\n",
    "CREATE TABLE abalone AS (\n",
    "    SELECT ROW_NUMBER() OVER() AS id, *\n",
    "    FROM abalone_external\n",
    ") DISTRIBUTED BY (sex);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train-Test Set Split\n",
    "\n",
    "Before proceeding data exploration, let's split our dataset to train and test set.\n",
    "\n",
    "- Firstly, we fetch a random value between 0 and 1 to each row;\n",
    "- Then we create a percentile table that stores percentile values for each sex;\n",
    "- Finally, we join those 2 tables to obtain our training or test tables.\n",
    "\n",
    "But since Ordered-Set Aggregate Function is not yet supported by the current version, we will skip this step with GreenplumPython and implement it with SQL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * postgresql://localhost/gpadmin\n",
      "12531 rows affected.\n",
      "3 rows affected.\n",
      "3 rows affected.\n",
      "Done.\n",
      "10026 rows affected.\n",
      "Done.\n",
      "2508 rows affected.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%sql\n",
    "CREATE TEMP TABLE temp_abalone_label AS\n",
    "    (SELECT *, random() AS __samp_out_label FROM abalone);\n",
    "\n",
    "CREATE TEMP TABLE train_percentile_disc AS\n",
    "    (SELECT sex, percentile_disc(0.8) within GROUP (ORDER BY __samp_out_label) AS __samp_out_label\n",
    "    FROM temp_abalone_label GROUP BY sex);\n",
    "CREATE TEMP TABLE test_percentile_disc AS\n",
    "    (SELECT sex, percentile_disc(0.2) within GROUP (ORDER BY __samp_out_label) AS __samp_out_label\n",
    "    FROM temp_abalone_label GROUP BY sex);\n",
    "\n",
    "DROP TABLE IF EXISTS abalone_train;\n",
    "CREATE TABLE abalone_train AS\n",
    "    (SELECT temp_abalone_label.*\n",
    "        FROM temp_abalone_label\n",
    "        INNER JOIN train_percentile_disc\n",
    "        ON temp_abalone_label.__samp_out_label <= train_percentile_disc.__samp_out_label\n",
    "        AND temp_abalone_label.sex = train_percentile_disc.sex\n",
    "    );\n",
    "DROP TABLE IF EXISTS abalone_test;\n",
    "CREATE TABLE abalone_test AS\n",
    "    (SELECT temp_abalone_label.*\n",
    "        FROM temp_abalone_label\n",
    "        INNER JOIN test_percentile_disc\n",
    "        ON temp_abalone_label.__samp_out_label <= test_percentile_disc.__samp_out_label\n",
    "        AND temp_abalone_label.sex = test_percentile_disc.sex\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that these features could be supported by GreenplumPython in future release."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import preparation\n",
    "\n",
    "Connect to Greenplum database named `gpadmin`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import greenplumpython as gp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = gp.database(\"postgresql://localhost/gpadmin\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Exploration\n",
    "\n",
    "Get access to the existing table `abalone`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "abalone = db.create_dataframe(table_name=\"abalone\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "\t<tr>\n",
       "\t\t<th>id</th>\n",
       "\t\t<th>sex</th>\n",
       "\t\t<th>length</th>\n",
       "\t\t<th>diameter</th>\n",
       "\t\t<th>height</th>\n",
       "\t\t<th>whole_weight</th>\n",
       "\t\t<th>shucked_weight</th>\n",
       "\t\t<th>viscera_weight</th>\n",
       "\t\t<th>shell_weight</th>\n",
       "\t\t<th>rings</th>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>1</td>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>0.455</td>\n",
       "\t\t<td>0.365</td>\n",
       "\t\t<td>0.095</td>\n",
       "\t\t<td>0.514</td>\n",
       "\t\t<td>0.2245</td>\n",
       "\t\t<td>0.101</td>\n",
       "\t\t<td>0.15</td>\n",
       "\t\t<td>15</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>2</td>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>0.35</td>\n",
       "\t\t<td>0.265</td>\n",
       "\t\t<td>0.09</td>\n",
       "\t\t<td>0.2255</td>\n",
       "\t\t<td>0.0995</td>\n",
       "\t\t<td>0.0485</td>\n",
       "\t\t<td>0.07</td>\n",
       "\t\t<td>7</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>3</td>\n",
       "\t\t<td>F</td>\n",
       "\t\t<td>0.53</td>\n",
       "\t\t<td>0.42</td>\n",
       "\t\t<td>0.135</td>\n",
       "\t\t<td>0.677</td>\n",
       "\t\t<td>0.2565</td>\n",
       "\t\t<td>0.1415</td>\n",
       "\t\t<td>0.21</td>\n",
       "\t\t<td>9</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>4</td>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>0.44</td>\n",
       "\t\t<td>0.365</td>\n",
       "\t\t<td>0.125</td>\n",
       "\t\t<td>0.516</td>\n",
       "\t\t<td>0.2155</td>\n",
       "\t\t<td>0.114</td>\n",
       "\t\t<td>0.155</td>\n",
       "\t\t<td>10</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>5</td>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>0.33</td>\n",
       "\t\t<td>0.255</td>\n",
       "\t\t<td>0.08</td>\n",
       "\t\t<td>0.205</td>\n",
       "\t\t<td>0.0895</td>\n",
       "\t\t<td>0.0395</td>\n",
       "\t\t<td>0.055</td>\n",
       "\t\t<td>7</td>\n",
       "\t</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "---------------------------------------------------------------------------------------------------------------\n",
       " id | sex | length | diameter | height | whole_weight | shucked_weight | viscera_weight | shell_weight | rings \n",
       "----+-----+--------+----------+--------+--------------+----------------+----------------+--------------+-------\n",
       "  1 | M   |  0.455 |    0.365 |  0.095 |        0.514 |         0.2245 |          0.101 |         0.15 |    15 \n",
       "  2 | M   |   0.35 |    0.265 |   0.09 |       0.2255 |         0.0995 |         0.0485 |         0.07 |     7 \n",
       "  3 | F   |   0.53 |     0.42 |  0.135 |        0.677 |         0.2565 |         0.1415 |         0.21 |     9 \n",
       "  4 | M   |   0.44 |    0.365 |  0.125 |        0.516 |         0.2155 |          0.114 |        0.155 |    10 \n",
       "  5 | I   |   0.33 |    0.255 |   0.08 |        0.205 |         0.0895 |         0.0395 |        0.055 |     7 \n",
       "---------------------------------------------------------------------------------------------------------------\n",
       "(5 rows)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SELECT * FROM abalone ORDER BY id LIMIT 5;\n",
    "\n",
    "abalone.order_by(\"id\")[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Observe the distribution of data on different segments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "\t<tr>\n",
       "\t\t<th>count</th>\n",
       "\t\t<th>gp_segment_id</th>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>3921</td>\n",
       "\t\t<td>2</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>8610</td>\n",
       "\t\t<td>1</td>\n",
       "\t</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "-----------------------\n",
       " count | gp_segment_id \n",
       "-------+---------------\n",
       "  3921 |             2 \n",
       "  8610 |             1 \n",
       "-----------------------\n",
       "(2 rows)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SELECT gp_execution_segment() AS gp_segment_id, COUNT(*)\n",
    "# FROM abalone\n",
    "# GROUP BY 1;\n",
    "\n",
    "import greenplumpython.builtins.functions as F\n",
    "\n",
    "abalone.assign(gp_segment_id=lambda _: gp.function(\"gp_execution_segment\")()).group_by(\n",
    "    \"gp_segment_id\"\n",
    ").apply(lambda _: F.count())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we already have table `abalone_train` ad `abalone_test` in the database, we can get access to them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "abalone_train = db.create_dataframe(table_name=\"abalone_train\")\n",
    "abalone_test = db.create_dataframe(table_name=\"abalone_test\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning to Make Predictions\n",
    "\n",
    "### Training Model with Linear Regression\n",
    "#### Creation of training function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "\n",
    "**NOTE:** \n",
    "Since the `linreg_func` will be executed on Greenplum Database, all the python dependencies (`scikit-learn`/`numpy`/etc.) needs to be discoverable by the `plpython` extension. Please check [Greenplum PL/Python Language](https://docs.vmware.com/en/VMware-Tanzu-Greenplum/6/greenplum-database/GUID-analytics-pl_python.html) for more information.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "# CREATE TYPE linreg_type AS (\n",
    "#    col_nm text[]\n",
    "#    , coef float8[]\n",
    "#    , intercept float8\n",
    "#    , serialized_linreg_model bytea\n",
    "#    , created_dt text\n",
    "# );\n",
    "\n",
    "import dataclasses\n",
    "import datetime\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class LinregType:\n",
    "    col_nm: List[str]\n",
    "    coef: List[float]\n",
    "    intercept: float\n",
    "    serialized_linreg_model: bytes\n",
    "    created_dt: str\n",
    "\n",
    "\n",
    "# -- Create function\n",
    "# -- Need to specify the return type -> API will create the corresponding type in Greenplum to return a row\n",
    "# -- Will add argument to change language extensions, currently plpython3u by default\n",
    "\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "\n",
    "@gp.create_column_function\n",
    "def linreg_func(length: List[float], shucked_weight: List[float], rings: List[int]) -> LinregType:\n",
    "    X = np.array([length, shucked_weight]).T\n",
    "    y = np.array([rings]).T\n",
    "\n",
    "    # OLS linear regression with length, shucked_weight\n",
    "    linreg_fit = LinearRegression().fit(X, y)\n",
    "    linreg_coef = linreg_fit.coef_\n",
    "    linreg_intercept = linreg_fit.intercept_\n",
    "\n",
    "    # Serialization of the fitted model\n",
    "    serialized_linreg_model = pickle.dumps(linreg_fit, protocol=3)\n",
    "\n",
    "    return LinregType(\n",
    "        col_nm=[\"length\", \"shucked_weight\"],\n",
    "        coef=linreg_coef[0],\n",
    "        intercept=linreg_intercept[0],\n",
    "        serialized_linreg_model=serialized_linreg_model,\n",
    "        created_dt=str(datetime.datetime.now()),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Apply `linreg_fitted` function to our train set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DROP TABLE IF EXISTS plc_linreg_fitted;\n",
    "# CREATE TABLE plc_linreg_fitted AS (\n",
    "#    SELECT\n",
    "#        a.sex\n",
    "#        , (plc_linreg_func(\n",
    "#            a.length_agg\n",
    "#            , a.shucked_weight_agg\n",
    "#            , a.rings_agg)\n",
    "#        ).*\n",
    "#    FROM (\n",
    "#        SELECT\n",
    "#            sex\n",
    "#            , ARRAY_AGG(length) AS length_agg\n",
    "#            , ARRAY_AGG(shucked_weight) AS shucked_weight_agg\n",
    "#            , ARRAY_AGG(rings) AS rings_agg\n",
    "#        FROM abalone_split\n",
    "#        WHERE split = 1\n",
    "#        GROUP BY sex\n",
    "#    ) a\n",
    "# ) DISTRIBUTED BY (sex);\n",
    "\n",
    "linreg_fitted = abalone_train.group_by(\"sex\").apply(\n",
    "    lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take a look at models built:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "\t<tr>\n",
       "\t\t<th>sex</th>\n",
       "\t\t<th>col_nm</th>\n",
       "\t\t<th>coef</th>\n",
       "\t\t<th>intercept</th>\n",
       "\t\t<th>created_dt</th>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>F</td>\n",
       "\t\t<td>['length', 'shucked_weight']</td>\n",
       "\t\t<td>[25.9425725980866, -8.40468544064155]</td>\n",
       "\t\t<td>-0.145930582701281</td>\n",
       "\t\t<td>2023-02-23 20:58:21.485528</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>['length', 'shucked_weight']</td>\n",
       "\t\t<td>[15.4355302117113, 0.370827174250677]</td>\n",
       "\t\t<td>1.23206474957206</td>\n",
       "\t\t<td>2023-02-23 20:58:21.485653</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>['length', 'shucked_weight']</td>\n",
       "\t\t<td>[22.6450273736554, -6.02938247876165]</td>\n",
       "\t\t<td>0.569975195099637</td>\n",
       "\t\t<td>2023-02-23 20:58:21.489115</td>\n",
       "\t</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "------------------------------------------------------------------------------------------------------------------------------\n",
       " sex | col_nm                       | coef                                  | intercept          | created_dt                 \n",
       "-----+------------------------------+---------------------------------------+--------------------+----------------------------\n",
       " F   | ['length', 'shucked_weight'] | [25.9425725980866, -8.40468544064155] | -0.145930582701281 | 2023-02-23 20:58:21.485528 \n",
       " I   | ['length', 'shucked_weight'] | [15.4355302117113, 0.370827174250677] |   1.23206474957206 | 2023-02-23 20:58:21.485653 \n",
       " M   | ['length', 'shucked_weight'] | [22.6450273736554, -6.02938247876165] |  0.569975195099637 | 2023-02-23 20:58:21.489115 \n",
       "------------------------------------------------------------------------------------------------------------------------------\n",
       "(3 rows)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "linreg_fitted[[\"sex\", \"col_nm\", \"coef\", \"intercept\", \"created_dt\"]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Summary of Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Using `ARRAY_APPEND`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "\t<tr>\n",
       "\t\t<th>sex</th>\n",
       "\t\t<th>col_nm2</th>\n",
       "\t\t<th>coef2</th>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>F</td>\n",
       "\t\t<td>length</td>\n",
       "\t\t<td>25.9425725980866</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>F</td>\n",
       "\t\t<td>shucked_weight</td>\n",
       "\t\t<td>-8.40468544064155</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>F</td>\n",
       "\t\t<td>intercept</td>\n",
       "\t\t<td>-0.145930582701292</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>length</td>\n",
       "\t\t<td>15.4355302117113</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>shucked_weight</td>\n",
       "\t\t<td>0.370827174250674</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>intercept</td>\n",
       "\t\t<td>1.23206474957206</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>length</td>\n",
       "\t\t<td>22.6450273736554</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>shucked_weight</td>\n",
       "\t\t<td>-6.02938247876165</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>intercept</td>\n",
       "\t\t<td>0.569975195099643</td>\n",
       "\t</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "-------------------------------------------\n",
       " sex | col_nm2        | coef2              \n",
       "-----+----------------+--------------------\n",
       " F   | length         |   25.9425725980866 \n",
       " F   | shucked_weight |  -8.40468544064155 \n",
       " F   | intercept      | -0.145930582701292 \n",
       " I   | length         |   15.4355302117113 \n",
       " I   | shucked_weight |  0.370827174250674 \n",
       " I   | intercept      |   1.23206474957206 \n",
       " M   | length         |   22.6450273736554 \n",
       " M   | shucked_weight |  -6.02938247876165 \n",
       " M   | intercept      |  0.569975195099643 \n",
       "-------------------------------------------\n",
       "(9 rows)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SELECT\n",
    "#   sex,\n",
    "#   UNNEST(ARRAY_APPEND(col_nm, 'intercept')) AS col_nm2,\n",
    "#   UNNEST(ARRAY_APPEND(coef, intercept)) AS coef2\n",
    "# from linreg_fitted;\n",
    "\n",
    "unnest = gp.function(\"unnest\")\n",
    "array_append = gp.function(\"array_append\")\n",
    "\n",
    "linreg_fitted.assign(\n",
    "    col_nm2=lambda t: unnest(array_append(t[\"col_nm\"], \"intercept\")),\n",
    "    coef2=lambda t: unnest(array_append(t[\"coef\"], t[\"intercept\"])),\n",
    ")[[\"sex\", \"col_nm2\", \"coef2\"]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction\n",
    "\n",
    "#### Creation of prediction function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "@gp.create_function\n",
    "def linreg_pred_func(serialized_model: bytes, length: float, shucked_weight: float) -> float:\n",
    "    # Deserialize the serialized model\n",
    "    model = pickle.loads(serialized_model)\n",
    "    features = [length, shucked_weight]\n",
    "    # Predict the target variable\n",
    "    y_pred = model.predict([features])\n",
    "    return y_pred[0][0]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Join model dataframe and test set dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "linreg_test_fit = linreg_fitted.inner_join(\n",
    "    abalone_test,\n",
    "    cond=lambda t1, t2: t1[\"sex\"] == t2[\"sex\"],\n",
    "    self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"],\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Predict test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CREATE TABLE plc_linreg_pred_dot AS (\n",
    "#    SELECT\n",
    "#        test.id\n",
    "#        , test.sex\n",
    "#        , test.rings\n",
    "#        , plc_linreg_pred_dot_func(\n",
    "#            model.coef\n",
    "#            , model.intercept\n",
    "#            , ARRAY[length, shucked_weight]\n",
    "#        ) AS y_pred\n",
    "#    FROM\n",
    "#        (SELECT * FROM abalone_split WHERE split=0) AS test\n",
    "#        , plc_linreg_fitted AS model\n",
    "#    WHERE test.sex = model.sex\n",
    "# ) DISTRIBUTED BY (sex);\n",
    "\n",
    "\n",
    "linreg_pred = linreg_test_fit.assign(\n",
    "    rings_pred=lambda t: linreg_pred_func(\n",
    "        t[\"serialized_linreg_model\"],\n",
    "        t[\"length\"],\n",
    "        t[\"shucked_weight\"],\n",
    "    ),\n",
    ")[[\"id\", \"sex\", \"rings\", \"rings_pred\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "\t<tr>\n",
       "\t\t<th>id</th>\n",
       "\t\t<th>sex</th>\n",
       "\t\t<th>rings</th>\n",
       "\t\t<th>rings_pred</th>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>22</td>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>10</td>\n",
       "\t\t<td>7.1272324039624</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>45</td>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>4</td>\n",
       "\t\t<td>4.48001556958082</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>106</td>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>7</td>\n",
       "\t\t<td>6.35897875153222</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>107</td>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>8</td>\n",
       "\t\t<td>7.8444517211187</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>150</td>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>6</td>\n",
       "\t\t<td>6.27660952003415</td>\n",
       "\t</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "--------------------------------------\n",
       " id  | sex | rings | rings_pred       \n",
       "-----+-----+-------+------------------\n",
       "  22 | I   |    10 |  7.1272324039624 \n",
       "  45 | I   |     4 | 4.48001556958082 \n",
       " 106 | I   |     7 | 6.35897875153222 \n",
       " 107 | I   |     8 |  7.8444517211187 \n",
       " 150 | I   |     6 | 6.27660952003415 \n",
       "--------------------------------------\n",
       "(5 rows)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SELECT * FROM plc_linreg_pred WHERE sex='I' ORDER BY id LIMIT 5;\n",
    "\n",
    "linreg_pred[lambda t: t[\"sex\"] == \"I\"].order_by(\"id\")[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation of model\n",
    "\n",
    "#### Creation of evaluation return type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# CREATE TYPE plc_linreg_eval_type AS (\n",
    "#    mae float8\n",
    "#    , mape float8\n",
    "#    , mse float8\n",
    "#    , r2_score float8\n",
    "# );\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class linreg_eval_type:\n",
    "    mae: float\n",
    "    mape: float\n",
    "    mse: float\n",
    "    r2_score: float"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Creation of evaluation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "\n",
    "\n",
    "@gp.create_column_function\n",
    "def linreg_eval(y_actual: List[float], y_pred: List[float]) -> linreg_eval_type:\n",
    "    mae = mean_absolute_error(y_actual, y_pred)\n",
    "    mse = mean_squared_error(y_actual, y_pred)\n",
    "    r2_score_ = r2_score(y_actual, y_pred)\n",
    "\n",
    "    y_pred_f = np.array(y_pred, dtype=float)\n",
    "    mape = 100 * sum(abs(y_actual - y_pred_f) / y_actual) / len(y_actual)\n",
    "    return linreg_eval_type(mae, mape, mse, r2_score_)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluate our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "\t<tr>\n",
       "\t\t<th>sex</th>\n",
       "\t\t<th>mae</th>\n",
       "\t\t<th>mape</th>\n",
       "\t\t<th>mse</th>\n",
       "\t\t<th>r2_score</th>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>F</td>\n",
       "\t\t<td>2.14618814657571</td>\n",
       "\t\t<td>18.7708049967624</td>\n",
       "\t\t<td>8.68959806015986</td>\n",
       "\t\t<td>0.127677315539933</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>I</td>\n",
       "\t\t<td>1.21058504580913</td>\n",
       "\t\t<td>14.6104201739291</td>\n",
       "\t\t<td>3.30775897060488</td>\n",
       "\t\t<td>0.480766554461037</td>\n",
       "\t</tr>\n",
       "\t<tr>\n",
       "\t\t<td>M</td>\n",
       "\t\t<td>2.03231596723483</td>\n",
       "\t\t<td>19.0254559285428</td>\n",
       "\t\t<td>7.32630298752111</td>\n",
       "\t\t<td>0.199558502231448</td>\n",
       "\t</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "----------------------------------------------------------------------------------\n",
       " sex | mae              | mape             | mse              | r2_score          \n",
       "-----+------------------+------------------+------------------+-------------------\n",
       " F   | 2.14618814657571 | 18.7708049967624 | 8.68959806015986 | 0.127677315539933 \n",
       " I   | 1.21058504580913 | 14.6104201739291 | 3.30775897060488 | 0.480766554461037 \n",
       " M   | 2.03231596723483 | 19.0254559285428 | 7.32630298752111 | 0.199558502231448 \n",
       "----------------------------------------------------------------------------------\n",
       "(3 rows)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SELECT\n",
    "#    sex\n",
    "#    , (linreg_eval(rings_agg, y_pred_agg)).*\n",
    "# FROM (\n",
    "#    SELECT\n",
    "#        sex\n",
    "#        , ARRAY_AGG(rings) AS rings_agg\n",
    "#        , ARRAY_AGG(y_pred) AS y_pred_agg\n",
    "#    FROM plc_linreg_pred\n",
    "#    GROUP BY sex\n",
    "# ) a\n",
    "\n",
    "\n",
    "linreg_pred.group_by(\"sex\").apply(lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
